{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e2d63327-b7da-4060-8664-5cf1072d4cdb",
   "metadata": {},
   "source": [
    "# Generating Synthetic Wideband Spectrogram-Like Images\n",
    "In this notebook we will use TorchSig's `image_datasets` module to generate synthetic spectrograms of a few modeled signal types.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d747c181-067c-46a7-9deb-650c9bfff3c2",
   "metadata": {},
   "source": [
    "## Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "d0fa212e-6128-495f-8af6-4bba4629fc76",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from torchsig.image_datasets.datasets.synthetic_signals import GeneratorFunctionDataset, rectangle_signal_generator_function, tone_generator_function, repeated_signal_generator_function, chirp_generator_function\n",
    "from torchsig.image_datasets.datasets.file_loading_datasets import SOIExtractorDataset, LazyImageDirectoryDataset\n",
    "from torchsig.image_datasets.datasets.composites import ConcatDataset\n",
    "from torchsig.image_datasets.datasets.yolo_datasets import YOLOImageCompositeDataset, YOLODatasetAdapter, YOLODatum\n",
    "from torchsig.image_datasets.transforms.impairments import BlurTransform, RandomGaussianNoiseTransform, RandomImageResizeTransform, RandomRippleNoiseTransform, ScaleTransform, scale_dynamic_range, normalize_image\n",
    "from torchsig.image_datasets.plotting.plotting import plot_yolo_boxes_on_image, plot_yolo_datum\n",
    "from torchsig.image_datasets.datasets.protocols import CFGSignalProtocolDataset, FrequencyHoppingDataset, random_hopping, YOLOFrequencyHoppingDataset, YOLOCFGSignalProtocolDataset\n",
    "\n",
    "from torchsig.datasets.narrowband import NewNarrowband, StaticNarrowband\n",
    "from torchsig.utils.writer import DatasetCreator\n",
    "from torchsig.datasets.dataset_metadata import NarrowbandMetadata\n",
    "import torchsig.transforms.dataset_transforms as ST\n",
    "import numpy as np\n",
    "\n",
    "from torchsig.image_datasets.dataset_generation import batched_write_yolo_synthetic_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9a375776-c294-4c66-bd54-9ee9911a29d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.rcParams[\"figure.figsize\"] = (10,10) #increase the pyplot figure display size in the notebook for better visibility"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a3b4a57-0b99-4f63-a60a-2e9ba1eca588",
   "metadata": {},
   "source": [
    "-----------------------------\n",
    "## Modeling Signals Using CFGSignalProtocolDataset\n",
    "\n",
    "A `CFGSignalProtocolDataset` is a dataset meant for modeling complicated relationships between individual components of a larger signal.\n",
    "\n",
    "Here we define two datasets, `rising_chirp` and `falling_chirp`, which return forward-slanted and backward-slanted diagonal lines respectively.\n",
    "We then combine those two basic components using a `CFGSignalProtocolDataset` to make a full signal shape.\n",
    "\n",
    "The `CFGSignalProtocolDataset` is an implementation of context-free grammar (CFG) logic for image datasets, such that each token of a string within the grammar is tied to either a dataset or a generator function which outputs image data. The result is an image composite containing all of the corresponding images side by side in order.\n",
    "We pass a single argument to the `CFGSignalProtocolDataset` constructor, which defines its start token.\n",
    "We then use the `CFGSignalProtocolDataset.add_rule(token_in, token_out | list(token_out_0, token_out_1, ...), relative_frequency)` method to add the logic of our CFG.\n",
    "\n",
    "Here, we are making a CFG starting on the token `cfg_signal` and using the following rules:\n",
    "- `cfg_signal` always evaluates as one `rising_or_falling_stream` followed by 12 `rising_falling_or_null`\n",
    "- `rising_falling_or_null` evaluates half the time as an empty string (ignored in the output image) and half the time as a `rising_or_falling_stream`\n",
    "- `rising_or_falling_stream` evaluates half the time as a `rising_stream` and half the time as a `falling_stream`\n",
    "- `rising_stream` and `falling_stream` evaluate to between one and three `rising_segment`s or `falling_segment`s respectively\n",
    "- a `rising_segment` is three `rising_chirp`s\n",
    "- a `falling_segment` is three `falling_chirp`s\n",
    "- `rising_chirp` and `falling_chirp` are evaluated using the `rising_chirp` and `falling_chirp` datasets\n",
    "\n",
    "The result is fed into a `YOLODatasetAdapter` to return objects of type `YOLODatum` instead of plain images (this will be useful later), and plotted below."
   ]
  },
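  {
   "cell_type": "markdown",
   "id": "cfg-expansion-sketch",
   "metadata": {},
   "source": [
    "The grammar rules listed above can be sketched in plain Python, independent of TorchSig. This is only an illustration of how weighted CFG expansion works, not the library's implementation: the `expand` helper and the layout of the `rules` dictionary are hypothetical, and an empty production stands in for the `null` token.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def expand(rules, token, rng):\n",
    "    # Recursively expand a token into a flat list of terminal tokens.\n",
    "    if token not in rules:  # no rule for this token: treat it as a terminal\n",
    "        return [token]\n",
    "    options = rules[token]\n",
    "    weights = [weight for _, weight in options]\n",
    "    # pick one production, weighted by its relative frequency\n",
    "    production, _ = rng.choices(options, weights=weights, k=1)[0]\n",
    "    out = []\n",
    "    for child in production:\n",
    "        out.extend(expand(rules, child, rng))\n",
    "    return out\n",
    "\n",
    "# token -> list of (production, relative_frequency); [] plays the role of 'null'\n",
    "rules = {\n",
    "    'cfg_signal': [(['rising_or_falling_stream'] + ['rising_falling_or_null'] * 12, 1)],\n",
    "    'rising_falling_or_null': [(['rising_or_falling_stream'], 1), ([], 1)],\n",
    "    'rising_or_falling_stream': [(['rising_stream'], 1), (['falling_stream'], 1)],\n",
    "    'rising_stream': [(['rising_segment'] + ['rising_segment_or_null'] * 2, 1)],\n",
    "    'rising_segment_or_null': [(['rising_segment'], 1), ([], 1)],\n",
    "    'rising_segment': [(['rising_chirp'] * 3, 1)],\n",
    "    'falling_stream': [(['falling_segment'] + ['falling_segment_or_null'] * 2, 1)],\n",
    "    'falling_segment_or_null': [(['falling_segment'], 1), ([], 1)],\n",
    "    'falling_segment': [(['falling_chirp'] * 3, 1)],\n",
    "}\n",
    "\n",
    "tokens = expand(rules, 'cfg_signal', random.Random(0))\n",
    "print(len(tokens), tokens[:6])\n",
    "```\n",
    "\n",
    "Every expansion yields a multiple of three chirps, since each segment is three chirps long. In the real dataset, each terminal token is replaced by an image drawn from the corresponding dataset."
   ]
  },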
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "78980771-86b3-44eb-bf27-e5c2715819ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_falling_edge(img_tnsr):\n",
    "    img_tnsr[:,:,-1] = 1\n",
    "    return img_tnsr\n",
    "\n",
    "rising_chirp = GeneratorFunctionDataset(chirp_generator_function(1, 20, 4, random_height_scale = [0.9,1.1], random_width_scale = [1,2]), \n",
    "                                        transforms=add_falling_edge)\n",
    "falling_chirp = GeneratorFunctionDataset(chirp_generator_function(1, 20, 4, random_height_scale = [0.9,1.1], random_width_scale = [1,2]), \n",
    "                                         transforms=[add_falling_edge, lambda x: x.flip(-1)])\n",
    "\n",
    "chirp_stream_ds = CFGSignalProtocolDataset('cfg_signal')\n",
    "chirp_stream_ds.add_rule('cfg_signal', ['rising_or_falling_stream'] + ['rising_falling_or_null']*12)\n",
    "chirp_stream_ds.add_rule('rising_falling_or_null', 'rising_or_falling_stream', 1)\n",
    "chirp_stream_ds.add_rule('rising_falling_or_null', 'null', 1)\n",
    "chirp_stream_ds.add_rule('null', None)\n",
    "chirp_stream_ds.add_rule('rising_or_falling_stream', 'rising_stream')\n",
    "chirp_stream_ds.add_rule('rising_or_falling_stream', 'falling_stream')\n",
    "chirp_stream_ds.add_rule('rising_stream', ['rising_segment'] + ['rising_segment_or_null']*2)\n",
    "chirp_stream_ds.add_rule('rising_segment_or_null', 'rising_segment')\n",
    "chirp_stream_ds.add_rule('rising_segment_or_null', 'null')\n",
    "chirp_stream_ds.add_rule('rising_segment', ['rising_chirp']*3)\n",
    "chirp_stream_ds.add_rule('rising_chirp', rising_chirp)\n",
    "chirp_stream_ds.add_rule('falling_stream', ['falling_segment'] + ['falling_segment_or_null']*2)\n",
    "chirp_stream_ds.add_rule('falling_segment_or_null', 'falling_segment')\n",
    "chirp_stream_ds.add_rule('falling_segment_or_null', 'null')\n",
    "chirp_stream_ds.add_rule('falling_segment', ['falling_chirp']*3)\n",
    "chirp_stream_ds.add_rule('falling_chirp', falling_chirp)\n",
    "\n",
    "yolo_chirp_stream_ds = YOLODatasetAdapter(chirp_stream_ds, class_id=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2455990a-d45c-4479-95b0-a4e5f5be35e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_yolo_datum(yolo_chirp_stream_ds[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "258873fa-bfe7-42d9-a463-727431b3e05b",
   "metadata": {},
   "source": [
    "-----------------------------\n",
    "Here is another kind of signal modeled using `CFGSignalProtocolDataset`. It consists of a fixed start and end sequence with between one and four 'bytes' in the middle, each made of eight large or small chirps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b794a276-c75d-4dfe-8d15-271d56b2eccc",
   "metadata": {},
   "outputs": [],
   "source": [
    "big_chirp_fn = chirp_generator_function(2, 8, 2, random_height_scale = [0.8,1.2], random_width_scale = [0.5,2])\n",
    "small_chirp_fn = chirp_generator_function(2, 4, 2, random_height_scale = [0.8,1.2], random_width_scale = [0.5,2])\n",
    "init_chirp_fn = chirp_generator_function(4, 2, 3, random_height_scale = [0.8,1.2], random_width_scale = [0.8,1.2])\n",
    "exit_chirp_fn = chirp_generator_function(2, 2, 2, random_height_scale = [0.8,1.2], random_width_scale = [0.8,1.2])\n",
    "\n",
    "bytes_ds = CFGSignalProtocolDataset('cfg_signal')\n",
    "bytes_ds.add_rule('cfg_signal', ['init_chirp', 'main_signal', 'exit_signal'])\n",
    "bytes_ds.add_rule('main_signal_chirp', ['bit_chirp']*8)\n",
    "bytes_ds.add_rule('bit_chirp', big_chirp_fn)\n",
    "bytes_ds.add_rule('bit_chirp', small_chirp_fn)\n",
    "bytes_ds.add_rule('main_signal', (['main_signal_chirp_or_null']*3)+['main_signal_chirp'])\n",
    "bytes_ds.add_rule('main_signal_chirp_or_null', 'main_signal_chirp', 1)\n",
    "bytes_ds.add_rule('main_signal_chirp_or_null', 'null', 1)\n",
    "bytes_ds.add_rule('null', None)\n",
    "bytes_ds.add_rule('init_chirp', init_chirp_fn)\n",
    "bytes_ds.add_rule('exit_signal', [exit_chirp_fn, exit_chirp_fn, exit_chirp_fn])\n",
    "\n",
    "yolo_bytes_ds = YOLODatasetAdapter(bytes_ds, class_id=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "800d5324-83b1-474e-8239-45e159e7e8c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_yolo_datum(yolo_bytes_ds[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3c19679-7dea-4864-8a54-b2661c12299d",
   "metadata": {},
   "source": [
    "-----------------------------\n",
    "## Modeling Signal Hopping Using YOLOFrequencyHoppingDataset\n",
    "\n",
    "A `YOLOFrequencyHoppingDataset` is a dataset meant for modeling channel hops. It is passed a number of channels, a channel size (in pixels), and a minimum and maximum number of bursts to send.\n",
    "It will then generate images simulating a series of bursts along different channels.\n",
    "A `hopping_function` can be passed in, which determines the order in which channels are selected. By default, channels are selected in either ascending or descending order (chosen randomly with 50/50 odds).\n",
    "\n",
    "Here we define two `YOLOFrequencyHoppingDataset` objects, both using bursts from the `yolo_bytes_ds` we defined above, but one hopping in linear order and the other hopping randomly.\n",
    "\n",
    "We then define a `YOLOCFGSignalProtocolDataset` which combines them and returns one, the other, or both in either order.\n",
    "\n",
    "`YOLOCFGSignalProtocolDataset` works just like `CFGSignalProtocolDataset` above, except that the data it handles are of type `YOLODatum` and not images alone. This is done to track internally where the bounding box around each signal burst should go."
   ]
  },
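  {
   "cell_type": "markdown",
   "id": "hopping-order-sketch",
   "metadata": {},
   "source": [
    "The two channel orderings described above can be illustrated with a short plain-Python sketch. The helpers `linear_hops` and `random_hops` are hypothetical and are not TorchSig functions. Wrapping around once the last channel is reached, and stacking channels along one image axis, are assumptions made only for this illustration.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def linear_hops(num_channels, num_bursts, rng):\n",
    "    # ascending or descending channel order, chosen with 50/50 odds\n",
    "    step = 1 if rng.random() < 0.5 else -1\n",
    "    start = 0 if step == 1 else num_channels - 1\n",
    "    # wrapping past the last channel is an assumption of this sketch\n",
    "    return [(start + step * i) % num_channels for i in range(num_bursts)]\n",
    "\n",
    "def random_hops(num_channels, num_bursts, rng):\n",
    "    # each burst picks a channel uniformly at random\n",
    "    return [rng.randrange(num_channels) for _ in range(num_bursts)]\n",
    "\n",
    "rng = random.Random(0)\n",
    "num_channels, channel_height = 10, 20  # channel size in pixels\n",
    "linear = linear_hops(num_channels, 6, rng)\n",
    "hopped = random_hops(num_channels, 6, rng)\n",
    "# pixel span occupied by the channel of each burst\n",
    "rows = [(c * channel_height, (c + 1) * channel_height) for c in linear]\n",
    "print(linear, hopped, rows[0])\n",
    "```\n",
    "\n",
    "The linear ordering always moves to an adjacent channel, while the random ordering can jump anywhere in the band."
   ]
  },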
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "54ce6ba0-c325-4f2b-8729-20c4e512cab5",
   "metadata": {},
   "outputs": [],
   "source": [
    "hopper_ds = YOLOFrequencyHoppingDataset(yolo_bytes_ds, [20,80], 10, [100,100], [4,8])\n",
    "random_hopper_ds = YOLOFrequencyHoppingDataset(yolo_bytes_ds, [20,80], 10, [100,100], [4,8], hopping_function=random_hopping)\n",
    "\n",
    "two_mode_hopper = YOLOCFGSignalProtocolDataset('two_mode_hopping')\n",
    "two_mode_hopper.add_rule('two_mode_hopping',['12'])\n",
    "two_mode_hopper.add_rule('two_mode_hopping',['21'])\n",
    "two_mode_hopper.add_rule('two_mode_hopping',['1'])\n",
    "two_mode_hopper.add_rule('two_mode_hopping',['2'])\n",
    "two_mode_hopper.add_rule('12',['1','2'])\n",
    "two_mode_hopper.add_rule('21',['2','1'])\n",
    "two_mode_hopper.add_rule('1',hopper_ds)\n",
    "two_mode_hopper.add_rule('2',random_hopper_ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c883df74-6816-4584-92b7-e2b86af562e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_yolo_datum(two_mode_hopper[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "701c114c-b7fe-404e-9c5f-1dc05722be3b",
   "metadata": {},
   "source": [
    "-----------------------------\n",
    "## Using TorchSig Narrowband Datasets in Images\n",
    "\n",
    "Here we create a `NewNarrowband` dataset with a `Spectrogram` transform so that it returns spectrogram images of the generated data. The dataset is written to disk with `DatasetCreator` and loaded back as a `StaticNarrowband`.\n",
    "\n",
    "This dataset is then passed into a `GeneratorFunctionDataset` with some filtering functions to clear the edges of the generated images.\n",
    "We also add some transforms to randomly resize and slightly blur the result. These are useful for image compositing.\n",
    "\n",
    "`GeneratorFunctionDataset` can take any zero-argument function which returns images and treat it as an image dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa13bed4-8ae6-4c91-9546-ebc198d8cc62",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick helper for cleaning up a narrowband spectrogram so it looks neat in composites; not strictly necessary, but very useful\n",
    "def threshold_mod_signal(signal):\n",
    "    signal[signal**4<0.02] = 0  # zero out values whose fourth power is below 0.02\n",
    "    return signal[:,:,5:-5]  # cut off the ends to look more seamless in composites\n",
    "\n",
    "mod_ds_transform = [\n",
    "    ST.Spectrogram(fft_size=128)\n",
    "]\n",
    "\n",
    "\n",
    "md = NarrowbandMetadata(\n",
    "    num_iq_samples_dataset = 2*128**2,\n",
    "    fft_size = 128,\n",
    "    impairment_level = 0,\n",
    "    num_samples = 10,\n",
    "    transforms = mod_ds_transform\n",
    ")\n",
    "ds = NewNarrowband(\n",
    "    dataset_metadata = md\n",
    ")\n",
    "\n",
    "dc = DatasetCreator(\n",
    "    dataset = ds,\n",
    "    root = \"./datasets/synthetic_spectrogram_example\",\n",
    "    overwrite = True\n",
    ")\n",
    "dc.create()\n",
    "\n",
    "mod_ds = StaticNarrowband(\n",
    "    root = \"./datasets/synthetic_spectrogram_example\",\n",
    "    impaired = False\n",
    ")\n",
    "\n",
    "signal_transforms = []\n",
    "signal_transforms += [normalize_image]\n",
    "signal_transforms += [RandomImageResizeTransform([0.02,1.5])]\n",
    "signal_transforms += [BlurTransform(strength=0.5, blur_shape=(5,1))]\n",
    "signal_transforms += [threshold_mod_signal]\n",
    "signal_transforms += [BlurTransform(strength=1, blur_shape=(5,1))]\n",
    "\n",
    "\n",
    "mod_spec_dataset = GeneratorFunctionDataset(lambda : 1 - mod_ds[np.random.randint(len(mod_ds))][0][None, :, :], transforms=signal_transforms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c2b1fdf-4c5a-44e7-a16d-49552601163b",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(mod_spec_dataset[0][0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed389479-d686-4a4d-a4a6-dfcbe5bd31e0",
   "metadata": {},
   "source": [
    "-----------------------------\n",
    "## Using YOLOImageCompositeDataset to Put Everything Together\n",
    "\n",
    "Here we create a `YOLOImageCompositeDataset` which combines all of the datasets above into large composite simulated wideband spectrogram images.\n",
    "\n",
    "`YOLOImageCompositeDataset` takes a `spectrogram_size` as input and is empty at creation.\n",
    "Components must be added using `YOLOImageCompositeDataset.add_component(dataset, min_to_add, max_to_add)`.\n",
    "Each component is added to every composite a uniformly random number of times between `min_to_add` and `max_to_add`.\n",
    "Components can be assigned a `class_id`, which will be the class ID of the resulting `YOLODatum`, or they can be set to `use_source_yolo_labels`, in which case the composite will add in any YOLO labels retrieved from the component dataset. Components with no `class_id` will be added to the image unlabeled and treated as background.\n",
    "\n",
    "Here we add the `yolo_chirp_stream_ds`, `bytes_ds`, `two_mode_hopper`, and `mod_spec_dataset` defined above to our `YOLOImageCompositeDataset` to simulate wideband data from an RF environment with many different types of transmitters present."
   ]
  },
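  {
   "cell_type": "markdown",
   "id": "yolo-label-sketch",
   "metadata": {},
   "source": [
    "To make the labeling concrete, the sketch below computes a label for a component pasted into a composite, assuming the common YOLO convention of a class ID followed by the box center and size, each normalized by the image dimensions. The `yolo_label` helper is hypothetical and is not part of TorchSig. Whether `YOLODatum` stores its labels in exactly this layout is an assumption.\n",
    "\n",
    "```python\n",
    "def yolo_label(class_id, x0, y0, box_w, box_h, img_w, img_h):\n",
    "    # (x0, y0) is the top-left pixel of the pasted component\n",
    "    cx = (x0 + box_w / 2) / img_w  # normalized box center, x\n",
    "    cy = (y0 + box_h / 2) / img_h  # normalized box center, y\n",
    "    return (class_id, cx, cy, box_w / img_w, box_h / img_h)\n",
    "\n",
    "# a 128x64 pixel component pasted at (256, 512) in a 1024x1024 composite\n",
    "label = yolo_label(0, x0=256, y0=512, box_w=128, box_h=64, img_w=1024, img_h=1024)\n",
    "print(label)  # (0, 0.3125, 0.53125, 0.125, 0.0625)\n",
    "```\n",
    "\n",
    "Because the coordinates are normalized, the same label stays valid if the composite is later resized."
   ]
  },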
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "365478a4-b398-495e-b1ce-873924b2e81e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Limits the dynamic range of the resulting image; stops overlapping image components from looking too dark or throwing off the rest of the image\n",
    "def clamp_max_by_std(signal):\n",
    "    limit = signal.mean() + signal.std()*3  # clamp threshold: three standard deviations above the mean\n",
    "    signal[signal > limit] = limit\n",
    "    return signal\n",
    "\n",
    "spectrogram_size = (1,1024,1024)\n",
    "\n",
    "two_mode_hopper.transforms = [lambda x: YOLODatum(BlurTransform(strength=1, blur_shape=(5,1))(x.img), x.labels)] # add a little extra blur to help blend with the composite\n",
    "composite_transforms = []\n",
    "composite_transforms += [clamp_max_by_std] # limit dynamic range due to overlapping signals\n",
    "composite_transforms += [normalize_image] # inf norm\n",
    "composite_transforms += [RandomGaussianNoiseTransform(mean=0, range=(0.2,0.8))] # add background noise\n",
    "composite_transforms += [scale_dynamic_range]\n",
    "composite_transforms += [normalize_image] # inf norm\n",
    "composite_spectrogram_dataset = YOLOImageCompositeDataset(spectrogram_size, transforms=composite_transforms, dataset_size=250000, max_add=True)\n",
    "composite_spectrogram_dataset.add_component(yolo_chirp_stream_ds, min_to_add=0, max_to_add=3, use_source_yolo_labels=True)\n",
    "composite_spectrogram_dataset.add_component(bytes_ds, min_to_add=0, max_to_add=3, class_id=0)\n",
    "composite_spectrogram_dataset.add_component(two_mode_hopper, min_to_add=0, max_to_add=1, use_source_yolo_labels=True)\n",
    "composite_spectrogram_dataset.add_component(mod_spec_dataset, min_to_add=1, max_to_add=3, class_id=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb39271b-4e4c-4a6c-b42d-256ddbcaf2fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_datum = composite_spectrogram_dataset[0]\n",
    "plot_yolo_datum(sample_datum)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9344c4c-50bf-48b5-9c67-8766487fb131",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(sample_datum[0][0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
